# Reksio - Memory Map Editor
# Copyright (C) 2023 CERN
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
# In applying this licence, CERN does not waive the privileges and immunities
# granted to it by virtue of its status as an Intergovernmental Organization or
# submit itself to any jurisdiction.

import reksio
import ast
import math
from functools import lru_cache
from collections import OrderedDict

# Separator used to join path components: dotted references found in an
# expression ("block.reg" is recorded as "block/reg") and node locations.
NODE_SEPARATOR = '/'

# Name fragments that mark a referenced name as a constant
# ("x-driver-edge" / "x-fesa" constant values) instead of a map variable.
CONST_PREFIXES = ['CFG_', 'CST_']


class TreeNode:
    '''
    Resolves a (possibly NODE_SEPARATOR-separated) name relative to a position
    in the memory-map tree.

    The first path component is searched among the children of ``node_parent``;
    while it is not found, the search climbs to the parent node, up to the root.
    The remaining components are then resolved downwards from the node found.
    Names are compared case-insensitively.

    Tree nodes are expected to provide ``name``, ``parent`` and ``children()``.
    '''

    def __init__(self, attr_name, node_parent):
        self.name = attr_name
        self.node_parent = node_parent
        # Kept for backward compatibility only: lookup() neither reads nor
        # modifies it anymore (it used to be mutable search state).
        self.node_lookup = node_parent
        self.children = OrderedDict()

    @staticmethod
    def _find_child(node, child_name):
        '''Return the first child of *node* whose name matches *child_name* (case-insensitive), or None.'''
        wanted = str.lower(child_name)
        return next((child for child in node.children() if str.lower(child.name) == wanted), None)

    def lookup(self, name, sep=NODE_SEPARATOR, _going_up_permitted=True):
        '''
        Return the tree node designated by *name*, or None if it cannot be found.

        :param name: node name; path components are separated by *sep*.
        :param sep: separator used in *name*.
        :param _going_up_permitted: when False, the first component is only
            searched among the children of ``node_parent`` (no climbing).
        :return: the matching tree node, or None.
        '''
        if not name:
            return None

        first, *rest = str.split(name, sep)

        # Look for the first component, climbing towards the root while it is
        # not found. The search is local (no cache, no state on self), so
        # successive lookups cannot influence each other.
        scope = self.node_parent
        element = self._find_child(scope, first)
        while element is None and _going_up_permitted and scope.parent is not None:
            scope = scope.parent
            element = self._find_child(scope, first)

        # The root of the path is found: resolve the rest downwards only.
        for component in rest:
            if element is None:
                break
            element = self._find_child(element, component)

        return element


# Full AST doc: https://greentreesnakes.readthedocs.io/en/latest/nodes.html#
#
# The validator is mainly meant to catch the most common syntax mistakes.
#
# WARNING: in "exotic" cases, a syntax error may NOT be detected!
class SyntaxValidator(ast.NodeVisitor):
    '''
    Checks an "x-conversions" expression for the most common syntax mistakes.

    Visiting a parsed expression fills ``names`` (variables and dotted
    references, the latter joined with NODE_SEPARATOR; called function names
    end up here too) and ``calls`` (names of called functions). It raises
    ``SyntaxError`` as soon as a forbidden construct is met: string literals,
    assignments, "//", "**", "*" unpacking, or "%" applied to a float literal.

    ``check_funcalls()`` and ``check_vars()`` are meant to be called after the
    visit; they report problems through ``reksio.warn`` and return a bool.
    '''

    def __init__(self, attr_full_name, attr_name, attr_node_parent):
        # Full path of the attribute, used as a prefix in every message.
        self.attr_full_name = attr_full_name
        self.attr_name = attr_name
        # Tree node owning the attribute; variable/constant lookups start here.
        self.attr_node_parent = attr_node_parent
        self.names = set()
        self.calls = set()

    # ----------------------------------------------------------------- names

    def visit_Name(self, node):
        self.names.add(node.id)

    def visit_Call(self, node):
        if not isinstance(node.func, ast.Name):
            # e.g. "math.sqrt(val)" or "f(a)(b)": there is no plain function
            # name to validate, so report it explicitly.
            raise SyntaxError(f'In attribute "{self.attr_full_name}": only calls of plain functions '
                              f'(e.g. "sqrt(val)") are supported!')
        self.calls.add(node.func.id)
        # Also visits node.func, so the function name is added to self.names.
        self.generic_visit(node)

    def visit_Attribute(self, node):
        '''
        Catches an attribute, like 'a.b.c',
        starting from c (node.attr), b (node.value.attr), a (node.value.value.id - Name),
        and records it as "a/b/c" (joined with NODE_SEPARATOR).
        '''
        parts = [node.attr]
        node_val = node.value

        while isinstance(node_val, ast.Attribute):
            parts.append(node_val.attr)
            node_val = node_val.value

        if not isinstance(node_val, ast.Name):
            # e.g. "(val + 1).real": the reference cannot be resolved in the map.
            raise SyntaxError(f'In attribute "{self.attr_full_name}": attribute reference '
                              f'"{node.attr}" must start with a plain name!')

        parts.append(node_val.id)
        self.names.add(NODE_SEPARATOR.join(reversed(parts)))

    # ---------------------------------------------------- forbidden constructs

    def visit_Constant(self, node):
        # Python >= 3.8 parses every literal into ast.Constant.
        if isinstance(node.value, str):
            raise SyntaxError(f'In attribute "{self.attr_full_name}": forbidden string variable "{node.value}" found!')
        self.generic_visit(node)

    def visit_Str(self, node):
        # Python < 3.8 parses string literals into ast.Str instead.
        raise SyntaxError(f'In attribute "{self.attr_full_name}": forbidden string variable "{node.s}" found!')

    def visit_AugAssign(self, node):
        raise SyntaxError(f'In attribute "{self.attr_full_name}": forbidden assigment symbol "=" found!')

    # Plain and annotated assignments are just as meaningless in a conversion expression.
    visit_Assign = visit_AugAssign
    visit_AnnAssign = visit_AugAssign

    def visit_FloorDiv(self, node):
        raise SyntaxError(f'In attribute "{self.attr_full_name}": forbidden symbol "//" found!')

    def visit_Pow(self, node):
        raise SyntaxError(f'In attribute "{self.attr_full_name}": forbidden power symbol "**" found!')

    def visit_Starred(self, node):
        raise SyntaxError(f'In attribute "{self.attr_full_name}": forbidded variable reference symbol "*" found!')

    @staticmethod
    def _is_non_int_literal(node):
        '''Return True if *node* is a float/complex numeric literal, optionally signed ("-1.5").'''
        while isinstance(node, ast.UnaryOp) and isinstance(node.op, (ast.USub, ast.UAdd)):
            node = node.operand
        if isinstance(node, ast.Constant):
            value = node.value
        else:
            value = getattr(node, 'n', None)  # legacy ast.Num (Python < 3.8)
        # bool is not a float/complex, so True/False are (correctly) not flagged.
        return isinstance(value, (float, complex))

    def visit_BinOp(self, node):
        self.generic_visit(node)

        if isinstance(node.op, ast.Mod):
            left = node.left
            right = node.right

            if self._is_non_int_literal(left) or self._is_non_int_literal(right):
                raise SyntaxError(
                    'Modulo operation "%" is used for float! - in C++, there is no modulo operation for floats! '
                    'For floats, instead of modulo, the "fmod()" should be used.')

            # The type of a variable is unknown here, so only warn.
            if isinstance(left, ast.Name):
                reksio.warn(f'In attribute "{self.attr_full_name}": for modulo operation "%", '
                            f'left variable/value "{left.id}" must be integer!')

            if isinstance(right, ast.Name):
                reksio.warn(f'In attribute "{self.attr_full_name}": for modulo operation "%", '
                            f'right variable/value "{right.id}" must be integer!')

    # ---------------------------------------------------------------- checks

    def check_funcalls(self):
        '''
        Warn about every called function that is neither a function of the
        python ``math`` module nor one of the extra supported builtins.

        :return: True if every called function is supported, False otherwise.
        '''
        valid = True

        # additional supported functions, which do not belong to python math lib
        extra_funcs = ('abs', 'round')

        for funcName in self.calls:
            if funcName in extra_funcs:
                continue

            # Only real math functions: constants such as "pi" or "e" and
            # private/dunder attributes are not callable functions.
            if not funcName.startswith('_') and callable(getattr(math, funcName, None)):
                continue

            reksio.warn(f'In attribute "{self.attr_full_name}":'
                        f' function "{funcName}" is not supported')
            valid = False

        return valid

    def get_constants(self):
        '''
        Collect the names of the constants visible from the attribute's node.

        Walks from ``attr_node_parent`` up to the root "memory-map" node and
        gathers the names of the children of each visited node's
        "x-driver-edge/constant-values" and "x-fesa/configuration-values" nodes.
        '''
        root_map = self.attr_node_parent.getRoot().getChildByType('memory-map')
        const_groups = (['x-driver-edge', 'constant-values'], ['x-fesa', 'configuration-values'])
        found_consts = set()
        parent_node = self.attr_node_parent

        # Also stop at the top of the tree, in case root_map is not an ancestor.
        while parent_node is not None:
            for group_path in const_groups:
                consts_node = parent_node.getChildByType(group_path)
                if consts_node:
                    found_consts.update(child.name for child in consts_node.children())

            if parent_node == root_map:
                break

            parent_node = parent_node.parent

        return found_consts

    def check_vars(self):
        '''
        Check that "val" is used and that every other referenced name exists.

        Plain variables are resolved in the memory-map tree (TreeNode.lookup);
        names containing one of CONST_PREFIXES must be among the constants
        returned by get_constants(). Every problem is reported through reksio.warn.

        :return: True if all checks passed, False otherwise.
        '''
        attr_node = TreeNode(self.attr_name, self.attr_node_parent)
        vars_found_flag = True

        vars_names = self.names - self.calls

        # NOTE(review): the prefixes are matched anywhere in the name, not only
        # at its start - confirm this is intended.
        const_names = {name for name in vars_names if any(prefix in name for prefix in CONST_PREFIXES)}
        vars_names -= const_names

        if 'val' not in self.names:
            reksio.warn(f'In attribute "{self.attr_full_name}": value("val") not used!')
            vars_found_flag = False
        else:
            # discard(), not remove(): "val" is missing from vars_names when it
            # is also used as a function name (e.g. "val(2)").
            vars_names.discard('val')

        for var in vars_names:
            if attr_node.lookup(var) is None:
                reksio.warn(f'In attribute "{self.attr_full_name}": variable "{var}" not found in map!')
                vars_found_flag = False

        # Used constants are checked against the "x-driver-edge" and "x-fesa"
        # constants of the attribute's node and of its ancestors up to the map.
        # A constant defined in another map or submap produces a warning.
        if not const_names:
            # Nothing to check: skip the tree walk done by get_constants().
            return vars_found_flag

        found_constants = self.get_constants()

        for var in const_names:
            if var not in found_constants:
                reksio.warn(f'In attribute "{self.attr_full_name}": constant "{var}" not found '
                            f'in parents (check "x-driver-edge" or "x-fesa" constants)!')
                vars_found_flag = False

        return vars_found_flag


def get_attr_full_name(attr):
    '''
    Return the location of the node owning *attr*, joined with NODE_SEPARATOR.

    Example: when ``attr.parentNode.nodeLocation()`` yields
    ``['map', 'block', 'reg1', 'value']``, the result is "map/block/reg1/value".
    '''
    owner = attr.parentNode
    return NODE_SEPARATOR.join(owner.nodeLocation())


def validate(attr):
    '''
    Validate the conversion expression held by *attr*.

    The expression (``attr.value``) is parsed with ``ast.parse`` and checked by
    SyntaxValidator for forbidden constructs, supported function calls and
    resolvable variables/constants. Problems are reported through reksio.warn.

    :param attr: attribute providing ``name``, ``value`` (the expression text)
        and ``parentNode`` (the tree node owning the attribute).
    :return: True if the expression passed every check, False otherwise.
    '''
    attr_name = attr.name
    attr_full_name = get_attr_full_name(attr)

    convfactor_text = attr.value
    validator = SyntaxValidator(attr_full_name, attr_name, attr.parentNode)

    try:
        validator.visit(ast.parse(convfactor_text))

        # Both checks always run, so that all of their warnings are emitted.
        funcalls_ok_flag = validator.check_funcalls()
        vars_ok_flag = validator.check_vars()

    # ValueError: before Python 3.12, ast.parse() raises it (instead of
    # SyntaxError) when the text contains null bytes.
    # AttributeError: may be raised while walking the AST or the memory-map tree.
    except (SyntaxError, ValueError, AttributeError) as errMsg:
        reksio.warn(f'Syntax error in "{attr_full_name}": "{convfactor_text}", error log: "{errMsg}"')
        return False

    return vars_ok_flag and funcalls_ok_flag


def getMessage():
    '''Return the summary message reported when an "x-conversions" attribute fails validation.'''
    message = 'Syntax error: "x-conversions" attribute not valid!'
    return message
